// Copyright (C) Kumo inc. and its affiliates.
// Author: Jeff.li lijippy@163.com
// All rights reserved.
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published
// by the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.
//

#pragma once

#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>

#include <nebula/compute/util.h>
#include <nebula/compute/util_internal.h>
#include <nebula/util/simd.h>
#include <turbo/utility/status.h>
#include <nebula/types/type_fwd.h>
#include <nebula/version.h>

namespace nebula::compute {

    // SwissTable is a variant of a hash table implementation.
    // This implementation is vectorized, that is: main interface methods take arrays of input
    // values and output arrays of result values.
    //
    // A detailed explanation of this data structure (including concepts such as blocks,
    // slots, stamps) and operations provided by this class is given in the document:
    // nebula/compute/exec/doc/key_map.md.
    //
    class TURBO_EXPORT SwissTable {
        friend class SwissTableMerge;

    public:
        SwissTable() = default;

        // cleanup() runs unconditionally; all members carry in-class
        // initializers so destroying a never-init()-ed table is well defined.
        ~SwissTable() { cleanup(); }

        // Callback that compares a batch of input keys against the keys already
        // stored under the given group ids, reporting the mismatching subset.
        using EqualImpl =
                std::function<void(int num_keys, const uint16_t *selection /* may be null */,
                                   const uint32_t *group_ids, uint32_t *out_num_keys_mismatch,
                                   uint16_t *out_selection_mismatch, void *callback_ctx)>;
        // Callback that appends the selected input keys as new groups.
        using AppendImpl =
                std::function<turbo::Status(int num_keys, const uint16_t *selection, void *callback_ctx)>;

        /// \brief Initialize the table state; must be called before any lookups
        /// or inserts.  Presumably allocates blocks_/hashes_ from `pool`
        /// (definition not visible here — confirm in the .cc file).
        turbo::Status init(int64_t hardware_flags, MemoryPool *pool, int log_blocks = 0,
                           bool no_hash_array = false);

        /// \brief Release table resources.  Invoked by the destructor.
        void cleanup();

        /// \brief Filter keys by hash only (may return false positives).
        /// Fills a match bit vector and, per key, the local slot within its block.
        void early_filter(const int num_keys, const uint32_t *hashes,
                          uint8_t *out_match_bitvector, uint8_t *out_local_slots) const;

        /// \brief Resolve candidate matches from early_filter() to exact group
        /// ids, using `equal_impl` to verify key equality.
        void find(const int num_keys, const uint32_t *hashes, uint8_t *inout_match_bitvector,
                  const uint8_t *local_slots, uint32_t *out_group_ids,
                  util::TempVectorStack *temp_stack, const EqualImpl &equal_impl,
                  void *callback_ctx) const;

        /// \brief Insert keys not found by find(), assigning them group ids.
        /// May resize the table.
        turbo::Status map_new_keys(uint32_t num_ids, uint16_t *ids, const uint32_t *hashes,
                                   uint32_t *group_ids, util::TempVectorStack *temp_stack,
                                   const EqualImpl &equal_impl, const AppendImpl &append_impl,
                                   void *callback_ctx);

        // Number of rows processed per internal mini-batch (power of two).
        int minibatch_size() const { return 1 << log_minibatch_; }

        // Number of keys inserted so far.
        uint32_t num_inserted() const { return num_inserted_; }

        int64_t hardware_flags() const { return hardware_flags_; }

        MemoryPool *pool() const { return pool_; }

        // Base-2 log of the number of blocks currently allocated.
        int log_blocks() const { return log_blocks_; }

        // Setter counterpart of num_inserted(); used by SwissTableMerge.
        void num_inserted(uint32_t i) { num_inserted_ = i; }

        // Raw pointer to block storage (status bytes + packed group ids).
        uint8_t *blocks() const { return blocks_->mutable_data(); }

        // Raw pointer to the per-slot hash array (undefined for empty slots).
        uint32_t *hashes() const {
            return reinterpret_cast<uint32_t *>(hashes_->mutable_data());
        }

        /// \brief Extract group id for a given slot in a given block.
        ///
        inline uint64_t extract_group_id(const uint8_t *block_ptr, int slot,
                                         uint64_t group_id_mask) const;

        /// \brief Write stamp and group id for a key into an empty slot.
        inline void insert_into_empty_slot(uint32_t slot_id, uint32_t hash, uint32_t group_id);

        // Width in bits of a packed group id, rounded up to 8/16/32/64 so that
        // ids can always be accessed through aligned 64-bit words.  A table with
        // 2^log_blocks blocks holds up to 2^(log_blocks + 3) slots, hence the +3.
        static int num_groupid_bits_from_log_blocks(int log_blocks) {
            int required_bits = log_blocks + 3;
            return required_bits <= 8 ? 8
                                      : required_bits <= 16 ? 16
                                                            : required_bits <= 32 ? 32
                                                                                  : 64;
        }

        // Use 32-bit hash for now
        static constexpr int bits_hash_ = 32;

    private:
        // Lookup helpers

        /// \brief Scan bytes in block in reverse and stop as soon
        /// as a position of interest is found.
        ///
        /// Positions of interest:
        /// a) slot with a matching stamp is encountered,
        /// b) first empty slot is encountered,
        /// c) we reach the end of the block.
        ///
        /// Optionally an index of the first slot to start the search from can be specified.
        /// In this case slots before it will be ignored.
        ///
        /// \param[in] block 8 byte block of hash table
        /// \param[in] stamp 7 bits of hash used as a stamp
        /// \param[in] start_slot Index of the first slot in the block to start search from.  We
        ///            assume that this index always points to a non-empty slot, equivalently
        ///            that it comes before any empty slots.  (Used only by one template
        ///            variant.)
        /// \param[out] out_slot index corresponding to the discovered position of interest (8
        ///            represents end of block).
        /// \param[out] out_match_found an integer flag (0 or 1) indicating if we reached an
        /// empty slot (0) or not (1). Therefore 1 can mean that either actual match was found
        /// (case a) above) or we reached the end of full block (case b) above).
        ///
        template<bool use_start_slot>
        inline void search_block(uint64_t block, int stamp, int start_slot, int *out_slot,
                                 int *out_match_found) const;

        // Gather group ids for the given keys from their (block, local slot) positions.
        void extract_group_ids(const int num_keys, const uint16_t *optional_selection,
                               const uint32_t *hashes, const uint8_t *local_slots,
                               uint32_t *out_group_ids) const;

        // Typed implementation of extract_group_ids; T is the storage width of
        // a packed group id (8/16/32/64 bits).
        template<typename T, bool use_selection>
        void extract_group_ids_imp(const int num_keys, const uint16_t *selection,
                                   const uint32_t *hashes, const uint8_t *local_slots,
                                   uint32_t *out_group_ids, int elements_offset,
                                   int element_multiplier) const;

        inline uint64_t next_slot_to_visit(uint64_t block_index, int slot,
                                           int match_found) const;

        inline uint64_t num_groups_for_resize() const;

        // Wrap a global slot id around the end of the table (open addressing).
        inline uint64_t wrap_global_slot_id(uint64_t global_slot_id) const;

        void init_slot_ids(const int num_keys, const uint16_t *selection,
                           const uint32_t *hashes, const uint8_t *local_slots,
                           const uint8_t *match_bitvector, uint32_t *out_slot_ids) const;

        void init_slot_ids_for_new_keys(uint32_t num_ids, const uint16_t *ids,
                                        const uint32_t *hashes, uint32_t *slot_ids) const;

        // Quickly filter out keys that have no matches based only on hash value and the
        // corresponding starting 64-bit block of slot status bytes. May return false positives.
        //
        void early_filter_imp(const int num_keys, const uint32_t *hashes,
                              uint8_t *out_match_bitvector, uint8_t *out_local_slots) const;

#if NEBULA_HAVE_RUNTIME_AVX2_SUPPORTED && defined(NEBULA_HAVE_RUNTIME_BMI2)
        // The functions below use BMI2 instructions, be careful before calling!
        int early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes,
                                     uint8_t* out_match_bitvector,
                                     uint8_t* out_local_slots) const;
        int early_filter_imp_avx2_x32(const int num_hashes, const uint32_t* hashes,
                                      uint8_t* out_match_bitvector,
                                      uint8_t* out_local_slots) const;
        int extract_group_ids_avx2(const int num_keys, const uint32_t* hashes,
                                   const uint8_t* local_slots, uint32_t* out_group_ids,
                                   int byte_offset, int byte_multiplier, int byte_size) const;
#endif

        // Run equal_impl over the selected keys and report the not-equal subset.
        void run_comparisons(const int num_keys, const uint16_t *optional_selection_ids,
                             const uint8_t *optional_selection_bitvector,
                             const uint32_t *groupids, int *out_num_not_equal,
                             uint16_t *out_not_equal_selection, const EqualImpl &equal_impl,
                             void *callback_ctx) const;

        inline bool find_next_stamp_match(const uint32_t hash, const uint32_t in_slot_id,
                                          uint32_t *out_slot_id, uint32_t *out_group_id) const;

        // Slow processing of input keys in the most generic case.
        // Handles inserting new keys.
        // Preexisting keys will be handled correctly, although the intended use is for this
        // call to follow a call to find() method, which would only pass on new keys that were
        // not present in the hash table.
        //
        turbo::Status map_new_keys_helper(const uint32_t *hashes, uint32_t *inout_num_selected,
                                          uint16_t *inout_selection, bool *out_need_resize,
                                          uint32_t *out_group_ids, uint32_t *out_next_slot_ids,
                                          util::TempVectorStack *temp_stack,
                                          const EqualImpl &equal_impl, const AppendImpl &append_impl,
                                          void *callback_ctx);

        // resize small hash tables when 50% full (up to 8KB).
        // resize large hash tables when 75% full.
        turbo::Status grow_double();

        // Number of hash bits stored in slots in a block.
        // The highest bits of hash determine block id.
        // The next set of highest bits is a "stamp" stored in a slot in a block.
        static constexpr int bits_stamp_ = 7;

        // Padding bytes added at the end of buffers for ease of SIMD access
        static constexpr int padding_ = 64;

        // Base 2 log of the mini-batch size.  Zero-initialized so that
        // minibatch_size() is defined (1) even before init() runs; init() is
        // expected to set the real value.
        int log_minibatch_ = 0;
        // Base 2 log of the number of blocks
        int log_blocks_ = 0;
        // Number of keys inserted into hash table
        uint32_t num_inserted_ = 0;

        // Data for blocks.
        // Each block has 8 status bytes for 8 slots, followed by 8 bit packed group ids for
        // these slots. In 8B status word, the order of bytes is reversed. Group ids are in
        // normal order. There is 64B padding at the end.
        //
        // 0 byte - 7 bucket | 1. byte - 6 bucket | ...
        // ---------------------------------------------------
        // |     Empty bit*   |    Empty bit       |
        // ---------------------------------------------------
        // |   7-bit hash    |    7-bit hash      |
        // ---------------------------------------------------
        // * Empty bucket has value 0x80. Non-empty bucket has highest bit set to 0.
        //
        std::shared_ptr<Buffer> blocks_;

        // Array of hashes of values inserted into slots.
        // Undefined if the corresponding slot is empty.
        // There is 64B padding at the end.
        std::shared_ptr<Buffer> hashes_;

        // CPU feature flags passed to init(); zero until then.
        int64_t hardware_flags_ = 0;
        // Memory pool passed to init(); null until then so that a destructor
        // running before init() never touches an indeterminate pointer.
        MemoryPool *pool_ = nullptr;
    };

    uint64_t SwissTable::extract_group_id(const uint8_t *block_ptr, int slot,
                                          uint64_t group_id_mask) const {
        // Bit-packed group ids for all 8 slots follow the 8 status bytes of the
        // block.  The id width is assumed to be rounded up to 8, 16, 32 or 64
        // bits, so the value never straddles an aligned 64-bit word and can be
        // read with a single load plus shift-and-mask.
        const int id_width = static_cast<int>(NEBULA_POPCOUNT64(group_id_mask));
        assert(id_width == 8 || id_width == 16 || id_width == 32 || id_width == 64);

        const int first_bit = slot * id_width;
        // "+ 1" skips the 64-bit status word; ">> 6" selects the word holding the id.
        const uint64_t *word =
                reinterpret_cast<const uint64_t *>(block_ptr) + 1 + (first_bit >> 6);
        return (*word >> (first_bit & 63)) & group_id_mask;
    }

    void SwissTable::insert_into_empty_slot(uint32_t slot_id, uint32_t hash,
                                            uint32_t group_id) {
        const uint64_t num_groupid_bits = num_groupid_bits_from_log_blocks(log_blocks_);

        // We assume here that the number of bits is rounded up to 8, 16, 32 or 64.
        // In that case we can insert group id value using aligned 64-bit word access.
        assert(num_groupid_bits == 8 || num_groupid_bits == 16 || num_groupid_bits == 32 ||
               num_groupid_bits == 64);

        const uint64_t num_block_bytes = (8 + num_groupid_bits);
        constexpr uint64_t stamp_mask = 0x7f;

        int start_slot = (slot_id & 7);
        int stamp =
                static_cast<int>((hash >> (bits_hash_ - log_blocks_ - bits_stamp_)) & stamp_mask);
        uint64_t block_id = slot_id >> 3;
        uint8_t *blockbase = blocks_->mutable_data() + num_block_bytes * block_id;

        blockbase[7 - start_slot] = static_cast<uint8_t>(stamp);
        int groupid_bit_offset = static_cast<int>(start_slot * num_groupid_bits);

        // Block status bytes should start at an address aligned to 8 bytes
        assert((reinterpret_cast<uint64_t>(blockbase) & 7) == 0);
        uint64_t *ptr = reinterpret_cast<uint64_t *>(blockbase) + 1 + (groupid_bit_offset >> 6);
        *ptr |= (static_cast<uint64_t>(group_id) << (groupid_bit_offset & 63));
    }

}  // namespace nebula::compute
